import matplotlib.pyplot as plt
import numpy as np
import os
import itertools
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.nn import Module
import time

import brevitas.nn as qnn
from brevitas.quant import Int8Bias as BiasQuant
from brevitas import config

## define the general training steps
def train(model, device, dataloader, criterion, optimizer, epoch):
    """Run one training epoch over ``dataloader`` and return aggregate metrics.

    Args:
        model: network to optimise; switched to ``train()`` mode here.
        device: device the input and label tensors are moved to.
        dataloader: iterable of batches where ``data[0]`` is the input and
            ``data[1]`` the labels; must expose ``.dataset`` supporting ``len()``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer whose gradients are zeroed and which is stepped
            once per batch.
        epoch: epoch number, used only in the printouts.

    Returns:
        ``(loss, accuracy)``: the sum of per-batch ``loss.item()`` values
        divided by ``len(dataloader.dataset)``, and the top-1 accuracy in
        percent over the whole dataset.
    """
    print('\n')
    print(f'Training Epoch: {epoch}')

    model.train()
    running_loss = 0.0
    correct = 0.0
    # Samples processed so far. Intermediate metrics are divided by this
    # instead of ``batch_idx * <current batch size>`` so they stay correct
    # when batches differ in size (e.g. a smaller final batch).
    seen = 0
    for batch_idx, data in enumerate(dataloader, 1):
        # zero the parameter gradients
        optimizer.zero_grad()

        # send input and label data to the device
        inputs, labels = data[0].to(device), data[1].to(device)
        seen += inputs.size(0)

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # metrics and printouts
        # NOTE(review): loss.item() is summed per batch and later divided by a
        # sample count; that is a per-sample mean only if criterion uses
        # reduction='sum'. With the default 'mean' the reported loss is scaled
        # down by the batch size — confirm which is intended.
        running_loss += loss.item()
        pred = outputs.argmax(dim=1, keepdim=True)  # index of the max logit = top-1 prediction
        correct += pred.eq(labels.view_as(pred)).sum().item()

        if batch_idx % 20 == 0:
            print(f'Epoch {epoch} - Train Loss: {running_loss/seen :.6f} - Train Top1 Accuracy: {100.*correct/seen :.6f}')

    print(f'Training Epoch: {epoch} Complete')
    print('\n')

    return running_loss/len(dataloader.dataset), 100.*correct/len(dataloader.dataset)

## define the general evaluation steps
def evaluate(model, device, dataloader, criterion, epoch):
    """Score ``model`` on ``dataloader`` without updating its parameters.

    The model is put in ``eval()`` mode and every batch runs under
    ``torch.no_grad()``; no backward pass or optimizer step happens.

    Args:
        model: network to evaluate.
        device: device the input and label tensors are moved to.
        dataloader: iterable of batches where element 0 is the input and
            element 1 the labels; must expose ``.dataset`` supporting ``len()``.
        criterion: loss function called as ``criterion(outputs, labels)``.
        epoch: epoch number, used only in the printouts.

    Returns:
        ``(loss, accuracy)``: the sum of per-batch ``criterion(...).item()``
        values divided by ``len(dataloader.dataset)``, and the top-1 accuracy
        in percent over the dataset.
    """
    print('\n')
    print(f'Eval Epoch: {epoch}')

    model.eval()
    total_loss = 0.0
    num_correct = 0.0
    with torch.no_grad():
        for batch in dataloader:
            x = batch[0].to(device)
            y = batch[1].to(device)

            # inference only
            logits = model(x)
            total_loss += criterion(logits, y).item()

            # top-1 prediction = index of the largest output per row
            top1 = logits.argmax(dim=1, keepdim=True)
            num_correct += top1.eq(y.view_as(top1)).sum().item()

    n_samples = len(dataloader.dataset)
    avg_loss = total_loss / n_samples
    accuracy = 100. * num_correct / n_samples
    print(f'Epoch {epoch} - Eval Loss: {avg_loss :.6f} - Eval Top1 Accuracy: {accuracy :.6f}')

    print(f'Eval Epoch: {epoch} Complete')
    print('\n')

    return avg_loss, accuracy

def quant_accumulator(model, layer_filter=None):
    """Collect the quantized weight of every ``nn.Linear``/``nn.Conv2d`` in ``model``.

    Args:
        model: module walked with ``named_modules()``. Every matched layer must
            provide ``quant_weight()`` returning an object with ``.value``
            (presumably a Brevitas quant layer — TODO confirm); otherwise
            ``AttributeError`` is raised.
        layer_filter: optional substring; layers whose qualified name contains
            it are skipped. ``None`` keeps every layer.

    Returns:
        list of ``module.quant_weight().value`` in ``named_modules()`` order,
        one entry per parameter of the layer whose name ends with ``'weight'``.

    Side effect:
        sets ``p.collect = False`` on the non-weight parameters of matched
        layers and on every ``requires_grad`` parameter reachable from other
        modules. The ``collect`` flag is not read in this file; presumably it
        is consumed elsewhere — confirm.
    """
    def layer_filt(nm):
        # Use the argument, not the enclosing loop variable ``name`` (which
        # only worked by closure coincidence).
        return layer_filter is None or layer_filter not in nm

    data = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)) and layer_filt(name):
            for n, p in module.named_parameters():
                if n.endswith('weight'):
                    data.append(module.quant_weight().value)
                else:
                    p.collect = False
            continue
        # NOTE(review): parameters() recurses into children, so container
        # modules (including the root) also flag the weights of the layers
        # collected above — confirm this is intended.
        for p in module.parameters():
            if p.requires_grad:
                p.collect = False

    return data

def float_accumulator(model, layer_filter=None):
    """Collect the float weight tensors of every ``nn.Linear``/``nn.Conv2d`` in ``model``.

    Args:
        model: module walked with ``named_modules()``.
        layer_filter: optional substring; layers whose qualified name contains
            it are skipped. ``None`` keeps every layer.

    Returns:
        list of ``p.data`` (sharing storage with the parameter) for each
        parameter of a matched layer whose name ends with ``'weight'``, in
        ``named_modules()`` order.

    Side effect:
        sets ``p.collect = False`` on the non-weight parameters of matched
        layers and on every ``requires_grad`` parameter reachable from other
        modules. The ``collect`` flag is not read in this file; presumably it
        is consumed elsewhere — confirm.
    """
    def layer_filt(nm):
        # Use the argument, not the enclosing loop variable ``name`` (which
        # only worked by closure coincidence).
        return layer_filter is None or layer_filter not in nm

    data = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)) and layer_filt(name):
            for n, p in module.named_parameters():
                if n.endswith('weight'):
                    data.append(p.data)
                else:
                    p.collect = False
            continue
        # NOTE(review): parameters() recurses into children, so container
        # modules (including the root) also flag the weights of the layers
        # collected above — confirm this is intended.
        for p in module.parameters():
            if p.requires_grad:
                p.collect = False

    return data